{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZjN_IJ8mhJ-4"
      },
      "source": [
        "##### Copyright 2023 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sY3Ffd83hK3b"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "03Pw58e6mTHI"
      },
      "source": [
        "# TF-NumPy Type Promotion"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l9nPKvxK-_pM"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/tf_numpy_type_promotion\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/tf_numpy_type_promotion.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/tf_numpy_type_promotion.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/tf_numpy_type_promotion.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uma-W5v__DYh"
      },
      "source": [
        "## Overview\n",
        "\n",
        "There are 4 options for type promotion in TensorFlow.\n",
        "\n",
        "- By default, TensorFlow raises errors instead of promoting types for mixed type operations.\n",
        "- Running `tf.numpy.experimental_enable_numpy_behavior()` switches TensorFlow to use [NumPy type promotion rules](https://www.tensorflow.org/guide/tf_numpy#type_promotion).\n",
        "- **This doc** describes two new options that will be available in TensorFlow 2.15 (or currently in `tf-nightly`):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vMvEKDFOsau7"
      },
      "outputs": [],
      "source": [
        "!pip install -q tf_nightly"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a6hOFBfPsd3y"
      },
      "source": [
        " **Note**: `experimental_enable_numpy_behavior` changes the behavior of all of TensorFlow."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ob1HNwUmYR5b"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AJR558zjAZQu"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow.experimental.numpy as tnp\n",
        "\n",
        "print(\"Using TensorFlow version %s\" % tf.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M6tacoy0DU6e"
      },
      "source": [
        "### Enabling the new type promotion\n",
        "\n",
        "In order to use the [JAX-like type promotion](https://jax.readthedocs.io/en/latest/type_promotion.html) in TF-Numpy, specify either `'all'` or `'safe'` as the dtype conversion mode when enabling NumPy behavior for TensorFlow.\n",
        "\n",
        "This new system (with `dtype_conversion_mode=\"all\"`) is associative, commutative, and makes it easy to control what width of float you end up with (it doesn't automatically convert to wider floats). It does introduce some risks of overflows and precision loss, but `dtype_conversion_mode=\"safe\"` forces you to handle those cases explicitly. The two modes are explained more in detail in the [next section](#two_modes)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TfCyofpFDQxm"
      },
      "outputs": [],
      "source": [
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sEMXK8-ZWMun"
      },
      "source": [
        "<a name=\"two_modes\">\n",
        "\n",
        "## Two Modes : ALL mode vs SAFE mode\n",
        "\n",
        "In the new type promotion system, we introduce two modes: `ALL` mode and `SAFE` mode. `SAFE` mode is used to mitigate the concerns of \"risky\" promotions that can result in precision loss or bit-widening."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-ULvTWj_KnHU"
      },
      "source": [
        "### Dtypes\n",
        "\n",
        "We will be using the following abbreviations for brevity.\n",
        "\n",
        "*   `b` means `tf.bool`\n",
        "*   `u8` means `tf.uint8`\n",
        "*   `i16` means `tf.int16`\n",
        "*   `i32` means `tf.int32`\n",
        "*   `bf16` means `tf.bfloat16`\n",
        "*   `f32` means `tf.float32`\n",
        "*   `f64` means `tf.float64`\n",
        "*   `i32*` means Python `int` or weakly-typed `i32`\n",
        "*   `f32*` means Python `float` or weakly-typed `f32`\n",
        "*   `c128*` means Python `complex` or weakly-typed `c128`\n",
        "\n",
        "The asterisk (*) denotes that the corresponding type is “weak” - such a dtype is temporarily inferred by the system, and could defer to other dtypes. This concept is explained more in detail [here](#weak_tensor)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hXZxLCkuzzq3"
      },
      "source": [
        "### Example of precision losing operations\n",
        "\n",
        "In the following example, `i32` + `f32` is allowed in `ALL` mode but\n",
        "not in `SAFE` mode due to the risk of precision loss."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y-yeIvstWStL"
      },
      "outputs": [],
      "source": [
        "# i32 + f32 returns a f32 result in ALL mode.\n",
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
        "a = tf.constant(10, dtype = tf.int32)\n",
        "b = tf.constant(5.0, dtype = tf.float32)\n",
        "a + b  # <tf.Tensor: shape=(), dtype=float32, numpy=15.0>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JNNmZow2WY3G"
      },
      "outputs": [],
      "source": [
        "# This promotion is not allowed in SAFE mode.\n",
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"safe\")\n",
        "a = tf.constant(10, dtype = tf.int32)\n",
        "b = tf.constant(5.0, dtype = tf.float32)\n",
        "try:\n",
        "  a + b\n",
        "except TypeError as e:\n",
        "   print(f'{type(e)}: {e}')  # TypeError: explicitly specify the dtype or switch to ALL mode."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f0x4Qhff0AKS"
      },
      "source": [
        "### Example of bit-widening operations\n",
        "\n",
        "In the following example, `i8` + `u32` is allowed in `ALL` mode but\n",
        "not in `SAFE` mode due to bit-widening, which means using more bits than the number of bits in the inputs. Note that the new type promotion semantics only allows necessary bit-widening."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Etbv-WoWzUXf"
      },
      "outputs": [],
      "source": [
        "# i8 + u32 returns an i64 result in ALL mode.\n",
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
        "a = tf.constant(10, dtype = tf.int8)\n",
        "b = tf.constant(5, dtype = tf.uint32)\n",
        "a + b"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yKRdvtvw0Lvt"
      },
      "outputs": [],
      "source": [
        "# This promotion is not allowed in SAFE mode.\n",
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"safe\")\n",
        "a = tf.constant(10, dtype = tf.int8)\n",
        "b = tf.constant(5, dtype = tf.uint32)\n",
        "try:\n",
        "  a + b\n",
        "except TypeError as e:\n",
        "   print(f'{type(e)}: {e}')  # TypeError: explicitly specify the dtype or switch to ALL mode."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yh2BwqUzH3C3"
      },
      "source": [
        "## A System Based on a Lattice"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HHUnfTPiYVN5"
      },
      "source": [
        "### Type Promotion Lattice\n",
        "\n",
        "The new type promotion behavior is determined via the following type promotion lattice:\n",
        "\n",
        "![Type Promotion Lattice](https://tensorflow.org/guide/images/new_type_promotion/type_promotion_lattice.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QykluwRyDDle"
      },
      "source": [
        "More specifically, promotion between any two types is determined by finding the first common child of the two nodes (including the nodes themselves).\n",
        "\n",
        "For example, in the diagram above, the first common child of `i8` and `i32` is `i32` because the two nodes intersect for the first time at `i32` when following the direction of the arrows.\n",
        "\n",
        "Similarly as another example, the result promotion type between `u64` and `f16` would be `f16`."
      ]
    },
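    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The first-common-child rule can be sketched in plain Python. This is only an illustration, not how TensorFlow implements promotion: the `LATTICE` dictionary is a simplified, hypothetical subset of the diagram (modeled on the JAX lattice, with weak and complex types left out), and `reachable` and `promote` are made-up helper names."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical, simplified lattice: each dtype maps to the dtypes that its\n",
        "# outgoing arrows point to. Weak and complex types are omitted.\n",
        "LATTICE = {\n",
        "  \"b\": [\"u8\", \"i8\"],\n",
        "  \"u8\": [\"u16\", \"i16\"],\n",
        "  \"i8\": [\"i16\"],\n",
        "  \"u16\": [\"u32\", \"i32\"],\n",
        "  \"i16\": [\"i32\"],\n",
        "  \"u32\": [\"u64\", \"i64\"],\n",
        "  \"i32\": [\"i64\"],\n",
        "  \"u64\": [\"f16\", \"bf16\"],\n",
        "  \"i64\": [\"f16\", \"bf16\"],\n",
        "  \"f16\": [\"f32\"],\n",
        "  \"bf16\": [\"f32\"],\n",
        "  \"f32\": [\"f64\"],\n",
        "  \"f64\": [],\n",
        "}\n",
        "\n",
        "\n",
        "def reachable(dtype):\n",
        "  # Returns `dtype` itself plus every dtype reachable by following arrows.\n",
        "  seen = {dtype}\n",
        "  stack = [dtype]\n",
        "  while stack:\n",
        "    for child in LATTICE[stack.pop()]:\n",
        "      if child not in seen:\n",
        "        seen.add(child)\n",
        "        stack.append(child)\n",
        "  return seen\n",
        "\n",
        "\n",
        "def promote(a, b):\n",
        "  # The first common child is the common node from which every other\n",
        "  # common node can still be reached.\n",
        "  common = reachable(a) & reachable(b)\n",
        "  for candidate in common:\n",
        "    if common <= reachable(candidate):\n",
        "      return candidate\n",
        "  raise TypeError(f\"No common child for {a} and {b}\")\n",
        "\n",
        "\n",
        "print(promote(\"i8\", \"i32\"))   # i32\n",
        "print(promote(\"u64\", \"f16\"))  # f16\n",
        "print(promote(\"i8\", \"u32\"))   # i64"
      ]
    },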
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nthziRHaDAUY"
      },
      "source": [
        "<a name=\"promotion_table\">\n",
        "\n",
        "### Type Promotion Table\n",
        "\n",
        "Following the lattice generates the binary promotion table below:\n",
        "\n",
        "**Note**: `SAFE` mode disallows the highlighted cells. `ALL` mode allows all cases.\n",
        "\n",
        "![Type Promotion Table](https://tensorflow.org/guide/images/new_type_promotion/type_promotion_table.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TPDt5QTkucSC"
      },
      "source": [
        "## Advantages of The New Type Promotion\n",
        "\n",
        "We adopt a JAX-like lattice-based system for our new type promotion, which offers the following advantages:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NUS_b13nue1p"
      },
      "source": [
        "<a name=\"lattice_system_design\">\n",
        "\n",
        "#### Advantages of Lattice-Based System\n",
        "\n",
        "First, using a lattice-based system ensures three very important properties:\n",
        "\n",
        "*   Existence: There is a unique result promotion type for any combinations of types.\n",
        "*   Commutativity: `a + b = b + a`\n",
        "*   Associativity: `a + (b + c) = (a + b) = c`\n",
        "\n",
        "These three properties are critical for constructing a type promotion semantics that is consistent and predictable."
      ]
    },
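    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sanity check, the commutativity and associativity of result dtypes can be spot-checked over a few dtypes. This is a sketch that assumes a TensorFlow build with `dtype_conversion_mode` support (2.15+ or `tf-nightly`); the sample of dtypes is arbitrary, and passing these assertions does not prove the properties in general."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import itertools\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow.experimental.numpy as tnp\n",
        "\n",
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
        "\n",
        "# An arbitrary sample of dtypes; any others could be substituted.\n",
        "sample_dtypes = (tf.int8, tf.int32, tf.float16, tf.float32)\n",
        "values = [tf.constant(1, dtype=d) for d in sample_dtypes]\n",
        "\n",
        "# Commutativity: swapping the operands gives the same result dtype.\n",
        "for a, b in itertools.combinations(values, 2):\n",
        "  assert (a + b).dtype == (b + a).dtype\n",
        "\n",
        "# Associativity: regrouping the operands gives the same result dtype.\n",
        "for a, b, c in itertools.permutations(values, 3):\n",
        "  assert ((a + b) + c).dtype == (a + (b + c)).dtype\n",
        "\n",
        "print(\"Result dtypes did not depend on operand order or grouping.\")"
      ]
    },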
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Sz88hRR6uhls"
      },
      "source": [
        "#### Advantages of JAX-like Lattice System\n",
        "\n",
        "Another crucial advantage of the JAX-like lattice system is that outside unsigned ints, it avoids all wider-than-necessary promotions. This means you cannot get 64-bit results without 64-bit inputs. This is especially beneficial for working on accelerators as it avoids unnecessary 64-bit values, which was frequent in the old type promotion."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rlylb7ieOVbJ"
      },
      "source": [
        "However, this comes with a trade-off: mixed float/integer promotion is very prone to precision loss. For instance, in the example below, `i64` + `f16` results in promoting `i64` to `f16`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "abqIkV02OXEF"
      },
      "outputs": [],
      "source": [
        "# The first input is promoted to f16 in ALL mode.\n",
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
        "tf.constant(1, tf.int64) + tf.constant(3.2, tf.float16)  # <tf.Tensor: shape=(), dtype=float16, numpy=4.2>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mYnh1gZdObfI"
      },
      "source": [
        "To migitage this concern, we introduced a `SAFE` mode that will disallow these \"risky\" promotions.\n",
        "\n",
        "**Note**: To learn more about the design considerations in constructing the lattice system, please refer to the [Design of Type Promotion Semantics for JAX](https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gAc7LFV0S2dP"
      },
      "source": [
        "<a name=\"weak_tensor\">\n",
        "\n",
        "## WeakTensor"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "olQ2gsFlS9BH"
      },
      "source": [
        "### Overview\n",
        "\n",
        "*Weak tensors* are Tensors that are \"weakly typed\", similar to a [concept in JAX](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax).\n",
        "\n",
        "`WeakTensor`'s dtype is temporarily inferred by the system, and could defer to other dtypes. This concept is introduced in the new type promotion to prevent unwanted type promotion within binary operations between TF values and values with no explicitly user-specified type, such as Python scalar literals."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MYmoFIqZTFtw"
      },
      "source": [
        "For instance, in the example below, `tf.constant(1.2)` is considered \"weak\" because it doesn't have a specific dtype. Therefore, `tf.constant(1.2)` defers to the type of `tf.constant(3.1, tf.float16)`, resulting in a `f16` output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eSBv_mzyTE97"
      },
      "outputs": [],
      "source": [
        "tf.constant(1.2) + tf.constant(3.1, tf.float16)  # <tf.Tensor: shape=(), dtype=float16, numpy=4.3>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KxuqBIFuTm5Z"
      },
      "source": [
        "### WeakTensor Construction\n",
        "\n",
        "WeakTensors are created if you create a tensor without specifying a dtype the result is a WeakTensor. You can check whether a Tensor is \"weak\" or not by checking the weak attribute at the end of the Tensor's string representation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7UmunnJ8True3"
      },
      "source": [
        "**First Case**: When `tf.constant` is called with an input with no user-specified dtype."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fLEtMluNTsI5"
      },
      "outputs": [],
      "source": [
        "tf.constant(5)  # <tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZQX6MBWHTt__"
      },
      "outputs": [],
      "source": [
        "tf.constant([5.0, 10.0, 3])  # <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10.,  3.], dtype=float32), weak=True>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ftsKSC5BTweP"
      },
      "outputs": [],
      "source": [
        "# A normal Tensor is created when dtype arg is specified.\n",
        "tf.constant(5, tf.int32)  # <tf.Tensor: shape=(), dtype=int32, numpy=5>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RqhoRy5iTyag"
      },
      "source": [
        "**Second Case**: When an input with no user-specified dtype is passed into a [WeakTensor-supporting API](#weak_tensor_apis)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DuwpgoQJTzE-"
      },
      "outputs": [],
      "source": [
        "tf.math.abs([100.0, 4.0])  # <tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>"
      ]
    },
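    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Whether the result of a binary op is weak depends on its inputs. The sketch below uses the string-representation check described above; `is_weak` is a hypothetical helper written for this illustration, not a TensorFlow API, and it assumes the `weak=True` marker shown in the examples. Two weak inputs are expected to give a weak result, while mixing in a tensor with an explicit dtype is expected to give a regular Tensor."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow.experimental.numpy as tnp\n",
        "\n",
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
        "\n",
        "\n",
        "def is_weak(tensor):\n",
        "  # Hypothetical helper: relies on the `weak=True` marker in the repr.\n",
        "  return \"weak=True\" in repr(tensor)\n",
        "\n",
        "\n",
        "weak = tf.constant(5)                    # No dtype given.\n",
        "strong = tf.constant(5, dtype=tf.int32)  # Explicit dtype given.\n",
        "\n",
        "print(is_weak(weak), is_weak(strong))\n",
        "print(is_weak(weak + weak))    # Both inputs are weak.\n",
        "print(is_weak(weak + strong))  # One input has an explicit dtype."
      ]
    },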
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UTcoR1xvR39k"
      },
      "source": [
        "##Effects of turning on the new type promotion\n",
        "\n",
        "Below is a non-exhaustive list of changes that result from turning on the new type promotion.\n",
        "\n",
        "*   More consistent and predictable promotion results.\n",
        "*   Reduced risk of bit-widening.\n",
        "*   `tf.Tensor` mathematical dunder methods use new type promotion.\n",
        "*   `tf.constant` can return `WeakTensor`.\n",
        "*   `tf.constant` allows implicit conversions when a Tensor input with a dtype different from the `dtype` arg is passed in.\n",
        "*   `tf.Variable` in-place ops (`assign`, `assign-add`, `assign-sub`) allow implicit conversions.\n",
        "*   `tnp.array(1)` and `tnp.array(1.0)` returns 32-bit WeakTensor.\n",
        "*   `WeakTensor`s will be created and used for [WeakTensor-supporting unary and binary API](#weak_tensor_apis)'s.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KyvonwYcsFX2"
      },
      "source": [
        "### More consistent and predictable promotion results\n",
        "\n",
        "Using a [lattice-based system](#lattice_system_design) allows the new type promotion to produce consistent and predictable type promotion results."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q0Z1njfb7lRa"
      },
      "source": [
        "#### Old Type Promotion\n",
        "\n",
        "Changing the order of operations produces inconsistent results using old type promotion."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M1Ca9v4m7z8e"
      },
      "outputs": [],
      "source": [
        "# Setup\n",
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"legacy\")\n",
        "a = np.array(1, dtype=np.int8)\n",
        "b = tf.constant(1)\n",
        "c = np.array(1, dtype=np.float16)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WwhTzJ-a4rTc"
      },
      "outputs": [],
      "source": [
        "# (a + b) + c throws an InvalidArgumentError.\n",
        "try:\n",
        "  tf.add(tf.add(a, b), c)\n",
        "except tf.errors.InvalidArgumentError as e:\n",
        "  print(f'{type(e)}: {e}')  # InvalidArgumentError"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d3qDgVYn7ezT"
      },
      "outputs": [],
      "source": [
        "# (b + a) + c returns an i32 result.\n",
        "tf.add(tf.add(b, a), c)  # <tf.Tensor: shape=(), dtype=int32, numpy=3>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YMH1skEs7oI5"
      },
      "source": [
        "#### New Type Promotion\n",
        "\n",
        "New type promotion produces consistent results regardless of the order."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BOHyJJ8z8uCN"
      },
      "outputs": [],
      "source": [
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
        "a = np.array(1, dtype=np.int8)\n",
        "b = tf.constant(1)\n",
        "c = np.array(1, dtype=np.float16)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZUKU70jf7E1l"
      },
      "outputs": [],
      "source": [
        "# (a + b) + c returns a f16 result.\n",
        "tf.add(tf.add(a, b), c)  # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YOEycjFx7qDn"
      },
      "outputs": [],
      "source": [
        "# (b + a) + c also returns a f16 result.\n",
        "tf.add(tf.add(b, a), c)  # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FpGMkm6aJsn6"
      },
      "source": [
        "### Reduced risk of bit-widening"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JxV2AL-U9Grg"
      },
      "source": [
        "#### Old Type Promotion\n",
        "\n",
        "Old type promotion often resulted in 64-bit results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7L1pxyvn9MlP"
      },
      "outputs": [],
      "source": [
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"legacy\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zMJVFdWf4XHp"
      },
      "outputs": [],
      "source": [
        "np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50)  # <tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fBhUH_wD9Is7"
      },
      "source": [
        "#### New Type Promotion\n",
        "\n",
        "New type promotion returns results with minimal number of bits necessary."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aJsj2ZyI9T9Y"
      },
      "outputs": [],
      "source": [
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jj0N_Plp4X9l"
      },
      "outputs": [],
      "source": [
        "np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50)  # <tf.Tensor: shape=(), dtype=float16, numpy=54.2>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yKUx7xe-KZ5O"
      },
      "source": [
        "### tf.Tensor mathematical dunder methods\n",
        "\n",
        "All `tf.Tensor` mathematical dunder methods will follow the new type promotion."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2c3icBUX4wNl"
      },
      "outputs": [],
      "source": [
        "-tf.constant(5)  # <tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ydJHQjid45s7"
      },
      "outputs": [],
      "source": [
        "tf.constant(5, tf.int16) - tf.constant(1, tf.float32)  # <tf.Tensor: shape=(), dtype=float32, numpy=4.0>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pLbIjIvbKqcU"
      },
      "source": [
        "### tf.Variable in-place ops\n",
        "\n",
        "Implicit conversions will be allowed in `tf.Variable` in-place ops.\n",
        "\n",
        "**Note**: Any promotion that results in a dtype that is different from the variable's original dtype will be not allowed. This is because `tf.Variable` cannot change its dtype."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QsXhyK1h-i5S"
      },
      "outputs": [],
      "source": [
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
        "a = tf.Variable(10, tf.int32)\n",
        "a.assign_add(tf.constant(5, tf.int16))  # <tf.Variable shape=() dtype=int32, numpy=15>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PiA4H-otLDit"
      },
      "source": [
        "### tf.constant implicit conversions\n",
        "\n",
        "In the old type promotion, `tf.constant` required an input Tensor to have the same dtype as the dtype argument. However, in the new type promotion, we implicitly convert Tensor to the specified dtype."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ArrQ9Dj0_OR8"
      },
      "outputs": [],
      "source": [
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
        "a = tf.constant(10, tf.int16)\n",
        "tf.constant(a, tf.float32)  # <tf.Tensor: shape=(), dtype=float32, numpy=10.0>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WAcK_-XnLWaP"
      },
      "source": [
        "### TF-NumPy Array\n",
        "\n",
        "`tnp.array` defaults to `i32*` and `f32*` for python inputs using the new type promotion."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K1pZnYNh_ahm"
      },
      "outputs": [],
      "source": [
        "tnp.array(1)  # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QoQl2PYP_fMT"
      },
      "outputs": [],
      "source": [
        "tnp.array(1.0)  # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wK5DpQ3Pz3k5"
      },
      "source": [
        "##Input Type Inference\n",
        "\n",
        "This is how different inputs' types are inferred in the new type promotion.\n",
        "\n",
        "\n",
        "*   `tf.Tensor`: Since `tf.Tensor` has a dtype property, we don't do further inference.\n",
        "*   NumPy types: This includes types like `np.array(1)`, `np.int16(1)`, and `np.float`. Since NumPy inputs also have a dtype property, we take the dtype property as the result inference type. Note that NumPy defaults to `i64` and `f64`.\n",
        "*   Python scalars/Nested types: This includes types like `1`, `[1, 2, 3]`, and `(1.0, 2.0)`.\n",
        "   *   Python `int` is inferred as `i32*`.\n",
        "   *   Python `float` is inferred as `f32*`.\n",
        "   *   Python `complex` is inferred as `c128*`.\n",
        "*  If the input doesn't fall into any of the above categories but has a dtype property, we take the dtype property as the result inference type."
      ]
    },
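    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The inference rules above can be observed by adding different kinds of inputs to the same `f16` tensor and printing the result dtype. This is a small sketch that assumes a TensorFlow build with `dtype_conversion_mode` support; the variable `x` and the values used are purely illustrative."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow.experimental.numpy as tnp\n",
        "\n",
        "tnp.experimental_enable_numpy_behavior(dtype_conversion_mode=\"all\")\n",
        "\n",
        "x = tf.constant(1.0, dtype=tf.float16)\n",
        "\n",
        "# Python float: inferred as weak f32, so it can defer to the dtype of x.\n",
        "print((x + 2.0).dtype)\n",
        "# NumPy scalar: its dtype property (f32) is used as a regular, non-weak dtype.\n",
        "print((x + np.float32(2.0)).dtype)\n",
        "# NumPy array created without a dtype: NumPy's own default dtype is used.\n",
        "print((x + np.array(2.0)).dtype)"
      ]
    },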
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g_SPfalfSPgg"
      },
      "source": [
        "# Further Reading\n",
        "\n",
        "The new type promotion closely resembles JAX-NumPy's type promotion. If you want to know more details about the new type promotion and the design choices, check out the resources below.\n",
        "\n",
        "*  [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html)\n",
        "*  [Design of Type Promotion Semantics for JAX](https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html)\n",
        "*  [Old TF-NumPy Promotion Semantics](https://www.tensorflow.org/guide/tf_numpy#type_promotion)\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qg5xBbImT31S"
      },
      "source": [
        "# References"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gjB0CVhVXBfW"
      },
      "source": [
        "<a name=\"weak_tensor_apis\">\n",
        "\n",
        "## WeakTensor-supporting APIs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_GVbqlN9aBS2"
      },
      "source": [
        "Below is a list of APIs that supports `WeakTensor`.\n",
        "\n",
        "For an unary op, this means that if an input with no user-specified type is passed in, it will return a `WeakTensor`.\n",
        "\n",
        "For a binary op, it will follow the promotion table [here](#promotion_table). It may or may not return a `WeakTensor` depending on the promotion result of the two inputs.\n",
        "\n",
        "**Note**: All mathematical operations (`+`, `-`, `*`, ...) are supported."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gi-G68Z8WN2P"
      },
      "source": [
        "* `tf.bitwise.invert`\n",
        "* `tf.clip_by_value`\n",
        "* `tf.debugging.check_numerics`\n",
        "* `tf.expand_dims`\n",
        "* `tf.identity`\n",
        "* `tf.image.adjust_brightness`\n",
        "* `tf.image.adjust_gamma`\n",
        "* `tf.image.extract_patches`\n",
        "* `tf.image.random_brightness`\n",
        "* `tf.image.stateless_random_brightness`\n",
        "* `tf.linalg.diag`\n",
        "* `tf.linalg.diag_part`\n",
        "* `tf.linalg.matmul`\n",
        "* `tf.linalg.matrix_transpose`\n",
        "* `tf.linalg.tensor_diag_part`\n",
        "* `tf.linalg.trace`\n",
        "* `tf.math.abs`\n",
        "* `tf.math.acos`\n",
        "* `tf.math.acosh`\n",
        "* `tf.math.add`\n",
        "* `tf.math.angle`\n",
        "* `tf.math.asin`\n",
        "* `tf.math.asinh`\n",
        "* `tf.math.atan`\n",
        "* `tf.math.atanh`\n",
        "* `tf.math.ceil`\n",
        "* `tf.math.conj`\n",
        "* `tf.math.cos`\n",
        "* `tf.math.cosh`\n",
        "* `tf.math.digamma`\n",
        "* `tf.math.divide_no_nan`\n",
        "* `tf.math.divide`\n",
        "* `tf.math.erf`\n",
        "* `tf.math.erfc`\n",
        "* `tf.math.erfcinv`\n",
        "* `tf.math.erfinv`\n",
        "* `tf.math.exp`\n",
        "* `tf.math.expm1`\n",
        "* `tf.math.floor`\n",
        "* `tf.math.floordiv`\n",
        "* `tf.math.floormod`\n",
        "* `tf.math.imag`\n",
        "* `tf.math.lgamma`\n",
        "* `tf.math.log1p`\n",
        "* `tf.math.log_sigmoid`\n",
        "* `tf.math.log`\n",
        "* `tf.math.multiply_no_nan`\n",
        "* `tf.math.multiply`\n",
        "* `tf.math.ndtri`\n",
        "* `tf.math.negative`\n",
        "* `tf.math.pow`\n",
        "* `tf.math.real`\n",
        "* `tf.math.real`\n",
        "* `tf.math.reciprocal_no_nan`\n",
        "* `tf.math.reciprocal`\n",
        "* `tf.math.reduce_euclidean_norm`\n",
        "* `tf.math.reduce_logsumexp`\n",
        "* `tf.math.reduce_max`\n",
        "* `tf.math.reduce_mean`\n",
        "* `tf.math.reduce_min`\n",
        "* `tf.math.reduce_prod`\n",
        "* `tf.math.reduce_std`\n",
        "* `tf.math.reduce_sum`\n",
        "* `tf.math.reduce_variance`\n",
        "* `tf.math.rint`\n",
        "* `tf.math.round`\n",
        "* `tf.math.rsqrt`\n",
        "* `tf.math.scalar_mul`\n",
        "* `tf.math.sigmoid`\n",
        "* `tf.math.sign`\n",
        "* `tf.math.sin`\n",
        "* `tf.math.sinh`\n",
        "* `tf.math.softplus`\n",
        "* `tf.math.special.bessel_i0`\n",
        "* `tf.math.special.bessel_i0e`\n",
        "* `tf.math.special.bessel_i1`\n",
        "* `tf.math.special.bessel_i1e`\n",
        "* `tf.math.special.bessel_j0`\n",
        "* `tf.math.special.bessel_j1`\n",
        "* `tf.math.special.bessel_k0`\n",
        "* `tf.math.special.bessel_k0e`\n",
        "* `tf.math.special.bessel_k1`\n",
        "* `tf.math.special.bessel_k1e`\n",
        "* `tf.math.special.bessel_y0`\n",
        "* `tf.math.special.bessel_y1`\n",
        "* `tf.math.special.dawsn`\n",
        "* `tf.math.special.expint`\n",
        "* `tf.math.special.fresnel_cos`\n",
        "* `tf.math.special.fresnel_sin`\n",
        "* `tf.math.special.spence`\n",
        "* `tf.math.sqrt`\n",
        "* `tf.math.square`\n",
        "* `tf.math.subtract`\n",
        "* `tf.math.tan`\n",
        "* `tf.math.tanh`\n",
        "* `tf.nn.depth_to_space`\n",
        "* `tf.nn.elu`\n",
        "* `tf.nn.gelu`\n",
        "* `tf.nn.leaky_relu`\n",
        "* `tf.nn.log_softmax`\n",
        "* `tf.nn.relu6`\n",
        "* `tf.nn.relu`\n",
        "* `tf.nn.selu`\n",
        "* `tf.nn.softsign`\n",
        "* `tf.nn.space_to_depth`\n",
        "* `tf.nn.swish`\n",
        "* `tf.ones_like`\n",
        "* `tf.realdiv`\n",
        "* `tf.reshape`\n",
        "* `tf.squeeze`\n",
        "* `tf.stop_gradient`\n",
        "* `tf.transpose`\n",
        "* `tf.truncatediv`\n",
        "* `tf.truncatemod`\n",
        "* `tf.zeros_like`\n",
        "* `tf.experimental.numpy.abs`\n",
        "* `tf.experimental.numpy.absolute`\n",
        "* `tf.experimental.numpy.amax`\n",
        "* `tf.experimental.numpy.amin`\n",
        "* `tf.experimental.numpy.angle`\n",
        "* `tf.experimental.numpy.arange`\n",
        "* `tf.experimental.numpy.arccos`\n",
        "* `tf.experimental.numpy.arccosh`\n",
        "* `tf.experimental.numpy.arcsin`\n",
        "* `tf.experimental.numpy.arcsinh`\n",
        "* `tf.experimental.numpy.arctan`\n",
        "* `tf.experimental.numpy.arctanh`\n",
        "* `tf.experimental.numpy.around`\n",
        "* `tf.experimental.numpy.array`\n",
        "* `tf.experimental.numpy.asanyarray`\n",
        "* `tf.experimental.numpy.asarray`\n",
        "* `tf.experimental.numpy.ascontiguousarray`\n",
        "* `tf.experimental.numpy.average`\n",
        "* `tf.experimental.numpy.bitwise_not`\n",
        "* `tf.experimental.numpy.cbrt`\n",
        "* `tf.experimental.numpy.ceil`\n",
        "* `tf.experimental.numpy.conj`\n",
        "* `tf.experimental.numpy.conjugate`\n",
        "* `tf.experimental.numpy.copy`\n",
        "* `tf.experimental.numpy.cos`\n",
        "* `tf.experimental.numpy.cosh`\n",
        "* `tf.experimental.numpy.cumprod`\n",
        "* `tf.experimental.numpy.cumsum`\n",
        "* `tf.experimental.numpy.deg2rad`\n",
        "* `tf.experimental.numpy.diag`\n",
        "* `tf.experimental.numpy.diagflat`\n",
        "* `tf.experimental.numpy.diagonal`\n",
        "* `tf.experimental.numpy.diff`\n",
        "* `tf.experimental.numpy.empty_like`\n",
        "* `tf.experimental.numpy.exp2`\n",
        "* `tf.experimental.numpy.exp`\n",
        "* `tf.experimental.numpy.expand_dims`\n",
        "* `tf.experimental.numpy.expm1`\n",
        "* `tf.experimental.numpy.fabs`\n",
        "* `tf.experimental.numpy.fix`\n",
        "* `tf.experimental.numpy.flatten`\n",
        "* `tf.experimental.numpy.flip`\n",
        "* `tf.experimental.numpy.fliplr`\n",
        "* `tf.experimental.numpy.flipud`\n",
        "* `tf.experimental.numpy.floor`\n",
        "* `tf.experimental.numpy.full_like`\n",
        "* `tf.experimental.numpy.imag`\n",
        "* `tf.experimental.numpy.log10`\n",
        "* `tf.experimental.numpy.log1p`\n",
        "* `tf.experimental.numpy.log2`\n",
        "* `tf.experimental.numpy.log`\n",
        "* `tf.experimental.numpy.max`\n",
        "* `tf.experimental.numpy.mean`\n",
        "* `tf.experimental.numpy.min`\n",
        "* `tf.experimental.numpy.moveaxis`\n",
        "* `tf.experimental.numpy.nanmean`\n",
        "* `tf.experimental.numpy.negative`\n",
        "* `tf.experimental.numpy.ones_like`\n",
        "* `tf.experimental.numpy.positive`\n",
        "* `tf.experimental.numpy.prod`\n",
        "* `tf.experimental.numpy.rad2deg`\n",
        "* `tf.experimental.numpy.ravel`\n",
        "* `tf.experimental.numpy.real`\n",
        "* `tf.experimental.numpy.reciprocal`\n",
        "* `tf.experimental.numpy.repeat`\n",
        "* `tf.experimental.numpy.reshape`\n",
        "* `tf.experimental.numpy.rot90`\n",
        "* `tf.experimental.numpy.round`\n",
        "* `tf.experimental.numpy.signbit`\n",
        "* `tf.experimental.numpy.sin`\n",
        "* `tf.experimental.numpy.sinc`\n",
        "* `tf.experimental.numpy.sinh`\n",
        "* `tf.experimental.numpy.sort`\n",
        "* `tf.experimental.numpy.sqrt`\n",
        "* `tf.experimental.numpy.square`\n",
        "* `tf.experimental.numpy.squeeze`\n",
        "* `tf.experimental.numpy.std`\n",
        "* `tf.experimental.numpy.sum`\n",
        "* `tf.experimental.numpy.swapaxes`\n",
        "* `tf.experimental.numpy.tan`\n",
        "* `tf.experimental.numpy.tanh`\n",
        "* `tf.experimental.numpy.trace`\n",
        "* `tf.experimental.numpy.transpose`\n",
        "* `tf.experimental.numpy.triu`\n",
        "* `tf.experimental.numpy.vander`\n",
        "* `tf.experimental.numpy.var`\n",
        "* `tf.experimental.numpy.zeros_like`"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "tf_numpy_type_promotion.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
